---
id: image-classification
title: Image Classification
---

import ExpoSnack from '@site/src/components/ExpoSnack';

In this tutorial, you will build an app that can take pictures and classify objects in each image using an on-device image classification model.

## Viewing this Demo

To view this demo, [download the PlayTorch app](/docs/tutorials/get-started.mdx#download-the-playtorch-app).


## Preview

If you want a sneak peek at what you'll be building, run this Snack by scanning the QR code in the PlayTorch app!

<ExpoSnack snackId="@playtorch/image-classification" />

## Overview

We'll go through the following steps:

1. Create a new project with Snack by Expo
2. Run the project in the PlayTorch app
3. Add PlayTorch dependencies
4. Add a camera view
5. Process an image
6. Run a model
7. Display results

## Starting a New Project

We will be using a tool called Snack by Expo to write our code in the browser and then run it on our device. To learn more about Snack, visit [this link](https://docs.expo.dev/workflow/snack/).

Open a new tab in your web browser and navigate to [snack.expo.dev](https://snack.expo.dev).

You will see a code editor with the `App.js` file open. On the right side of the window, you will see several options for running your code. It defaults to "Web", but let's select "My Device" so we can use the PlayTorch app to enable ML in our app.

![](/img/tutorials/snacks/image-classification/starter-snack.png)

## Run the New Project

Open the PlayTorch app and from the home screen, tap "Scan QR Code".

If you have never done this before, it will ask for camera permissions. Grant the app camera permissions and scan the QR code from the right side of the Snack window.

If you haven't made any changes to the snack, you should see a screen that looks like this:

![](/img/tutorials/snacks/image-classification/starter-screen.jpg)

Try changing the `backgroundColor` to `#800080` on line 29 and you will see that your app screen changes in real time to match it.

## Add PlayTorch Dependencies

In order to add ML to this simple demo, we first need to add the PlayTorch dependencies.

On the left side of the Snack window, you will see a list of the files used in your Snack. Open the one called `package.json` and replace its contents with the following:

```json
{
  "dependencies": {
    "react-native-pytorch-core": "0.2.0",
    "react-native-safe-area-context": "3.3.2"
  }
}
```

This is the list of external libraries that we will use to build our ML-powered demo.

## Add a Camera View

Now that we have the extra dependencies loaded, we can use them to prepare our user interface for collecting images to classify.

Go ahead and replace the contents of `App.js` with the following. Let's walk through what changes are included:

1. Import dependencies. Note that we import the `Camera` component from the `react-native-pytorch-core` package, which is the core PlayTorch SDK.
2. Update the `App` function to render our new UI:
   1. Get the "safe area insets", which tell us how much of the screen we can actually render to while avoiding camera notches and bottom bars.
   2. Make the `Camera` view fill the whole screen except for the unsafe area at the bottom, so the capture button doesn't get obscured.
   3. Create a label container that floats near the top, for when we begin classifying images.
3. Create a `styles` object that sets the styles for our label container.

```js title="App.js"
// 1. Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

// 2. App function to render a camera and a text
export default function App() {
  // 2.i. Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();
  return (
    <View style={StyleSheet.absoluteFill}>
      {/* 2.ii. Render camera and make it parent filling */}
      <Camera style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]} />
      {/* 2.iii. Label container with custom render style and a text */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// 3. Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});
```

Once you make these changes, open the Snack back up in the PlayTorch app and you will see the camera view filling the screen and our label container with a placeholder label.

Notice that pressing the capture button doesn't do anything yet. Let's fix that.

The added lines below do the following:

1. Create an async function called `handleImage` (async functions can wait on slow work without blocking the UI) that simply:
   1. Logs the image object passed to it.
   2. Releases the image from memory. If `image.release()` is not called, the camera will not provide an image on subsequent presses of the capture button, and the only fix is to force close the PlayTorch app and reopen it.
2. Set the `handleImage` function to be called every time an image is captured by the `Camera` component.

```js {12-18,25-26} title="App.js"
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';

// App function to render a camera and a text
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();

  // 1. Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // 1.i. Log the image object to the console
    console.log(image);
    // 1.ii. Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render camera and make it parent filling */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // 2. Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});
```

Open the logs in the Snack window by clicking the settings gear icon at the bottom of the window, enabling the Panel, and clicking the logs tab of the newly opened panel.

![](/img/tutorials/snacks/image-classification/snack-open-logs.png)

After taking a picture, you should see a logged object with an `ID` field.
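
Because a missed `image.release()` blocks further captures, it can help to release the image in a `finally` block so that the release runs even when processing throws. The sketch below is one defensive pattern, not part of the PlayTorch SDK: `handleImageSafely` and `processImage` are hypothetical names, and `createFakeImage` is a stand-in object so the helper can be exercised outside the app.

```js
// Sketch: always release the captured image, even if processing throws.
// `processImage` is a hypothetical stand-in for whatever work is done
// with the image (logging it, classifying it, and so on).
async function handleImageSafely(image, processImage) {
  try {
    return await processImage(image);
  } finally {
    // Runs on success and on failure, so the camera can deliver the
    // next image
    image.release();
  }
}

// Stand-in for a camera image (an assumption for illustration only)
function createFakeImage() {
  return {
    released: false,
    release() {
      this.released = true;
    },
  };
}
```

In `App.js`, the same idea would mean wrapping the body of `handleImage` in `try { ... } finally { image.release(); }`.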

Now that we can capture images, let's write some code to prepare them for machine learning!

## Process an Image

In order for us to run machine learning on our image to classify it, we first need to translate it to a format that the ML model understands.

ML models don't work with images directly, but with tensors (multidimensional arrays) that have their own specific data format.

The MobileNet model that we will be using requires its input tensor to hold image data that is exactly 224 by 224 pixels, normalized, and arranged in a specific shape.

Let's create a new file by clicking the new file button in the left pane of the Snack window. We'll call it `ImageClassifier.js`.

Let's walk through the code below to see how we get our image converted to a proper tensor:

1. Import `torch`, `torchvision`, and `media` from the `react-native-pytorch-core` package (the PlayTorch SDK)
2. Create an alias called `T` for the `transforms` object from `torchvision` to make the transform functions shorter to access
3. Create an async function called `classifyImage` that takes in an image and does the following:
   1. Grab the `width` and the `height` of the image
   2. Create a `blob` of the image (a blob is just a raw data representation of something). In this case, the blob holds a byte representation of the image in the format height, width, and channels, or HWC for short.
   3. Create a `tensor` from the blob with the shape height by width by channels (RGB). It is important that the order of HWC is aligned with the byte representation of the image.
   4. Rearrange the tensor shape to be channels (RGB) by height by width. This is important because the image classification model that is used in this tutorial requires the tensor shape to be in this order.
   5. Divide all of the values in the tensor by 255. This is important because the image classification model requires the tensor values to be in the range `[0, 1]`.
   6. Center crop the image data within the tensor. The center crop will result in a square image tensor with the shortest side defining the size.
   7. Resize the tensor to `3` by `224` by `224` (or tensor shape `[3, 224, 224]`) to match the size the model expects as tensor input format.
   8. Normalize the tensor image with mean and standard deviation.
   9. Add one more leading dimension to the tensor to be in the shape `1` by `3` by `224` by `224` (or tensor shape `[1, 3, 224, 224]`). The image classification model can classify multiple images in parallel. The leading `1` represents the batch size, which is `1` because in this tutorial it only needs to process one image at a time.
   10. Return the shape of the tensor, which is `[1, 3, 224, 224]`.

```js title="ImageClassifier.js"
// 1. Import torch, torchvision, and media from PlayTorch SDK
import {torch, torchvision, media} from 'react-native-pytorch-core';

// 2. Alias for torchvision transforms
const T = torchvision.transforms;

// 3. The classifyImage function that will process an image and return the top
// class label
export default async function classifyImage(image) {
  // 3.i. Get image width and height
  const width = image.getWidth();
  const height = image.getHeight();

  // 3.ii. Convert image to blob, which is a byte representation of the image
  // in the format height (H), width (W), and channels (C), or HWC for short
  const blob = media.toBlob(image);

  // 3.iii. Get a tensor from the image blob and also define what format
  // the image blob is in.
  let tensor = torch.fromBlob(blob, [height, width, 3]);

  // 3.iv. Rearrange the tensor shape to be [CHW]
  tensor = tensor.permute([2, 0, 1]);

  // 3.v. Divide the tensor values by 255 to get values between [0, 1]
  tensor = tensor.div(255);

  // 3.vi. Crop the image in the center to be a square image
  const centerCrop = T.centerCrop(Math.min(width, height));
  tensor = centerCrop(tensor);

  // 3.vii. Resize the image tensor to 3 x 224 x 224
  const resize = T.resize(224);
  tensor = resize(tensor);

  // 3.viii. Normalize the tensor image with mean and standard deviation
  const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
  tensor = normalize(tensor);

  // 3.ix. Unsqueeze adds 1 leading dimension to the tensor
  tensor = tensor.unsqueeze(0);

  // 3.x. Return the tensor shape [1, 3, 224, 224]
  return tensor.shape;
}
```
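
To make the reordering, scaling, and normalization steps more concrete, here is a plain JavaScript illustration that does not use any PlayTorch APIs. It is only a sketch: `hwcBytesToNormalizedChw` is a hypothetical helper, it skips the crop and resize steps, and it assumes the conventional `(value - mean) / std` definition of normalization. The mean and standard deviation values are the ones passed to `T.normalize` above.

```js
// Plain JavaScript illustration of HWC -> CHW reordering, scaling to
// [0, 1], and per-channel normalization. No PlayTorch APIs are used.
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

// `bytes` holds height * width * 3 values in HWC order (R, G, B per pixel)
function hwcBytesToNormalizedChw(bytes, height, width) {
  const channels = 3;
  const out = new Float32Array(channels * height * width);
  for (let h = 0; h < height; h++) {
    for (let w = 0; w < width; w++) {
      for (let c = 0; c < channels; c++) {
        // Position of this value in the HWC source and the CHW result
        const hwcIndex = (h * width + w) * channels + c;
        const chwIndex = (c * height + h) * width + w;
        // Scale the byte to [0, 1], then normalize per channel
        const scaled = bytes[hwcIndex] / 255;
        out[chwIndex] = (scaled - MEAN[c]) / STD[c];
      }
    }
  }
  return out;
}

// A 1 x 2 image: one pure red pixel followed by one pure blue pixel
const pixels = new Uint8Array([255, 0, 0, 0, 0, 255]);
const chw = hwcBytesToNormalizedChw(pixels, 1, 2);
console.log(Array.from(chw));
```

In the result, the two red values come first, then the two green values, then the two blue values, which is the channel-first layout the model expects.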

Let's double check the output of this function to make sure we are on the right track.

Go back to `App.js` and instead of just logging the `image` object, let's run the `classifyImage` function on the `image` object first and log the result instead.

1. Import the `classifyImage` function from the `ImageClassifier.js` file.
2. Call the `classifyImage` function with the `image` from the camera.
3. Log the `result` to the console.

```js {6-7,16-17,18-19} title="App.js"
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
// 1. Import classify image function
import classifyImage from './ImageClassifier';

// App function to render a camera and a text
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();

  // Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // 2. Call the classify image function with the camera image
    const result = await classifyImage(image);
    // 3. Log the result from classify image to the console
    console.log(result);
    // Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render camera and make it parent filling */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text */}
      <View style={styles.labelContainer}>
        <Text>Label will go here</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});
```

When you check the log output after capturing an image now, you should see `[1,3,224,224]`, which is the tensor shape we need.
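
If it helps to see where each dimension comes from, the shape after every step can be traced by hand. The sketch below is plain JavaScript with no PlayTorch calls; `traceShapes` is a hypothetical helper that simply restates the shapes described in the steps above, and the 480 by 640 input size is an arbitrary example.

```js
// Trace the tensor shape after each preprocessing step described above.
// This only computes shapes; it does not touch any image data.
function traceShapes(height, width) {
  // The center crop keeps a square whose side is the shorter dimension
  const side = Math.min(height, width);
  return [
    ['fromBlob', [height, width, 3]], // HWC, as stored in the blob
    ['permute', [3, height, width]], // CHW
    ['centerCrop', [3, side, side]], // square image
    ['resize', [3, 224, 224]], // size the model expects
    ['unsqueeze', [1, 3, 224, 224]], // leading batch dimension
  ];
}

// Example with an arbitrary 480 x 640 (height x width) image
for (const [step, shape] of traceShapes(480, 640)) {
  console.log(step, JSON.stringify(shape));
}
```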

Now that the image has been converted to a properly formatted tensor, we are ready to run the machine learning model!

## Run the Model

Let's head back to our `ImageClassifier.js` file and make some updates to the `classifyImage` function to actually classify the image.

For the changes we make to `ImageClassifier.js`, you'll need to upload a file containing the labels for the different things the model knows how to classify.

[Click here to download the JSON file](https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/ImageNetClasses.json) and then drag and drop it into the Snack window to upload it.

Here's a quick summary of the changes we are making to run the model:

1. Import the `MobileModel` to help us load our machine learning model
2. Import the class labels from the `ImageNetClasses.json` file. The JSON file maps image class indices to class labels.
3. Store the URL for the model we'll be using in a variable for later access
4. Create a variable for storing our `model` and set it to null initially
5. After we have the tensor all ready, check to see if our `model` is still null. If it is, initialize it by downloading it and loading it into memory.
6. Run the model on our image converted into a tensor by calling `model.forward(tensor)`. The return value will be a `Tensor` of shape `[1, 1000]`, where `1` is the batch size (remember that in this tutorial only 1 image is processed at a time) and `1000` is the number of probability values (one for each class in `ImageNetClasses.json`).
7. Find the index with the highest probability, which represents the most likely class detected in the image.
8. Resolve the most likely image class by mapping the index to the class label and return it.

```js {1,3,8-10,15-17,19-20,58-72} title="ImageClassifier.js"
// 1. Add import for MobileModel from PlayTorch SDK
import {
  MobileModel,
  torch,
  torchvision,
  media,
} from 'react-native-pytorch-core';
// 2. Import the ImageNetClasses JSON file, which is used below to map the
// processed model result to a class label
import * as ImageNetClasses from './ImageNetClasses.json';

// Alias for torchvision transforms
const T = torchvision.transforms;

// 3. URL to the image classification model that is used in this example
const MODEL_URL =
  'https://github.com/facebookresearch/playtorch/releases/download/v0.1.0/mobilenet_v3_small.ptl';

// 4. Variable to hold a reference to the loaded ML model
let model = null;

// The classifyImage function that will process an image and return the top
// class label
export default async function classifyImage(image) {
  // Get image width and height
  const width = image.getWidth();
  const height = image.getHeight();

  // Convert image to blob, which is a byte representation of the image
  // in the format height (H), width (W), and channels (C), or HWC for short
  const blob = media.toBlob(image);

  // Get a tensor from the image blob and also define what format
  // the image blob is in.
  let tensor = torch.fromBlob(blob, [height, width, 3]);

  // Rearrange the tensor shape to be [CHW]
  tensor = tensor.permute([2, 0, 1]);

  // Divide the tensor values by 255 to get values between [0, 1]
  tensor = tensor.div(255);

  // Crop the image in the center to be a square image
  const centerCrop = T.centerCrop(Math.min(width, height));
  tensor = centerCrop(tensor);

  // Resize the image tensor to 3 x 224 x 224
  const resize = T.resize(224);
  tensor = resize(tensor);

  // Normalize the tensor image with mean and standard deviation
  const normalize = T.normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]);
  tensor = normalize(tensor);

  // Unsqueeze adds 1 leading dimension to the tensor
  tensor = tensor.unsqueeze(0);

  // 5. If the model has not been loaded already, it will be downloaded from
  // the URL and then loaded into memory.
  if (model == null) {
    const filePath = await MobileModel.download(MODEL_URL);
    model = await torch.jit._loadForMobile(filePath);
  }

  // 6. Run the ML inference with the pre-processed image tensor
  const output = await model.forward(tensor);

  // 7. Get the index of the value with the highest probability
  const maxIdx = output.argmax().item();

  // 8. Resolve the most likely class label and return it
  return ImageNetClasses[maxIdx];
}
```
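
Steps 7 and 8 amount to an argmax followed by a lookup. The plain JavaScript sketch below shows the idea on a tiny example without any PlayTorch calls. The three-entry `labels` array and the `scores` values are made-up stand-ins for `ImageNetClasses.json` and the model output, and `argmax` and `softmax` are hypothetical helpers. The `softmax` function is included only to show that turning scores into values that sum to 1 does not change which index wins.

```js
// Made-up stand-ins for the label file and the model output
const labels = ['tench', 'goldfish', 'great white shark'];
const scores = [0.1, 2.5, -0.3];

// Index of the largest value
function argmax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) {
      best = i;
    }
  }
  return best;
}

// Convert arbitrary scores into positive values that sum to 1
function softmax(values) {
  // Subtracting the maximum keeps Math.exp from overflowing
  const max = Math.max(...values);
  const exps = values.map(v => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map(e => e / sum);
}

const topIndex = argmax(scores);
console.log(labels[topIndex]); // prints "goldfish"
console.log(argmax(softmax(scores)) === topIndex); // prints true
```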

:::note

Since we are initializing the model the first time we run the `classifyImage` function, it will be slower. Subsequent runs will go much faster since they don't have to download the model or load it into memory.

If you do not wish to upload your model to a publicly accessible server, you may instead place the file in a directory of your choice and replace the line `const filePath = await MobileModel.download(MODEL_URL);` with `const filePath = await MobileModel.download(require('./path/to/model.ptl'));`.

:::
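
One thing to be aware of with the `if (model == null)` check: if the capture button is pressed twice before the first download finishes, both calls could see `null` and start their own download. A possible way around this is to cache the in-flight promise instead of the loaded model. The sketch below is plain JavaScript under that assumption; `getModel` and `loadModel` are hypothetical names, and in this tutorial `loadModel` would be an async function wrapping the `MobileModel.download` and `torch.jit._loadForMobile` calls shown above.

```js
// Cache the in-flight promise so concurrent callers share one load
let modelPromise = null;

// `loadModel` is a hypothetical async function that downloads the model
// and loads it into memory
function getModel(loadModel) {
  if (modelPromise == null) {
    modelPromise = loadModel().catch(error => {
      // Reset the cache so a later call can retry after a failure
      modelPromise = null;
      throw error;
    });
  }
  return modelPromise;
}
```

Inside `classifyImage`, the null check could then be replaced with something like `const model = await getModel(loadModel);`.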

Now that we are actually running the model, let's try it out in the PlayTorch app again and see what it logs. You should see a class label in the logs which is a word or list of words.

Excellent! It's logging the classification of each picture!

## Display the Result

Lastly, let's update our UI to display the result of our model!

Go back to `App.js` and make the following changes:

1. Create a state variable to store the `topClass` we get from the model
2. In the `handleImage` function, set the `topClass` state variable to the result of the `classifyImage` function
3. Change the text in the UI to display the `topClass` state variable instead of the placeholder text

```js {13-17,23-24,39-40} title="App.js"
// Import dependencies
import * as React from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Camera} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
// Import classify image function
import classifyImage from './ImageClassifier';

// App function to render a camera and a text
export default function App() {
  // Safe area insets to compensate for notches and bottom bars
  const insets = useSafeAreaInsets();
  // 1. Create a React state to store the top class returned from the
  // classifyImage function
  const [topClass, setTopClass] = React.useState(
    "Press capture button to classify what's in the camera view!",
  );

  // Function to handle images whenever the user presses the capture button
  async function handleImage(image) {
    // Call the classify image function with the camera image
    const result = await classifyImage(image);
    // 2. Set result as top class label state
    setTopClass(result);
    // Release the image from memory
    image.release();
  }

  return (
    <View style={StyleSheet.absoluteFill}>
      {/* Render camera and make it parent filling */}
      <Camera
        style={[StyleSheet.absoluteFill, {bottom: insets.bottom}]}
        // Add handle image callback on the camera component
        onCapture={handleImage}
      />
      {/* Label container with custom render style and a text */}
      <View style={styles.labelContainer}>
        {/* 3. Change the text to render the top class label */}
        <Text>{topClass}</Text>
      </View>
    </View>
  );
}

// Custom render style for label container
const styles = StyleSheet.create({
  labelContainer: {
    padding: 20,
    margin: 20,
    marginTop: 40,
    borderRadius: 10,
    backgroundColor: 'white',
  },
});
```

And with that, you have a working image classifier!
